{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using `matplotlib` to display inline images\n",
    "\n",
    "In this notebook we will explore using `matplotlib` to display images in our notebooks, and work towards developing a reusable function to display 2D, 3D, color, and label overlays for SimpleITK images.\n",
    "\n",
    "We will also look at the subtleties of working with image filters that require the input images to overlap, that is, to occupy the same physical space."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import SimpleITK as sitk\n",
    "# Download data to work on\n",
    "%run update_path_to_download_script\n",
    "from downloaddata import fetch_data as fdata"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "SimpleITK has a built-in `Show` function that saves the image to disk and launches a user-configurable program (ImageJ by default) to display it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "simpleitk_error_allowed": "Exception thrown in SimpleITK Show:"
   },
   "outputs": [],
   "source": [
    "img1 = sitk.ReadImage(fdata(\"cthead1.png\"))\n",
    "sitk.Show(img1, title=\"cthead1\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "simpleitk_error_allowed": "Exception thrown in SimpleITK Show:"
   },
   "outputs": [],
   "source": [
    "img2 = sitk.ReadImage(fdata(\"VM1111Shrink-RGB.png\"))\n",
    "sitk.Show(img2, title=\"Visible Human Head\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nda = sitk.GetArrayViewFromImage(img1)\n",
    "plt.imshow(nda)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nda = sitk.GetArrayViewFromImage(img2)\n",
    "ax = plt.imshow(nda)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def myshow(img):\n",
    "    nda = sitk.GetArrayViewFromImage(img)\n",
    "    plt.imshow(nda)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "myshow(img2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "myshow(sitk.Expand(img2, [10]*5))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This image does not appear bigger.\n",
    "\n",
    "There are numerous improvements that we can make:\n",
    "\n",
    " - support 3D images\n",
    " - include a title\n",
    " - use physical pixel size for axis labels\n",
    " - show the image as gray values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def myshow(img, title=None, margin=0.05, dpi=80):\n",
    "    nda = sitk.GetArrayViewFromImage(img)\n",
    "    spacing = img.GetSpacing()\n",
    "        \n",
    "    if nda.ndim == 3:\n",
    "        # fastest dim, either component or x\n",
    "        c = nda.shape[-1]\n",
    "        \n",
    "        # if the number of components is 3 or 4, consider it an RGB(A) image;\n",
    "        # otherwise treat it as a 3D scalar image and take the middle z-slice\n",
    "        if c not in (3,4):\n",
    "            nda = nda[nda.shape[0]//2,:,:]\n",
    "    \n",
    "    elif nda.ndim == 4:\n",
    "        c = nda.shape[-1]\n",
    "        \n",
    "        if c not in (3,4):\n",
    "            raise RuntimeError(\"Unable to show 3D-vector Image\")\n",
    "            \n",
    "        # take a z-slice\n",
    "        nda = nda[nda.shape[0]//2,:,:,:]\n",
    "            \n",
    "    ysize = nda.shape[0]\n",
    "    xsize = nda.shape[1]\n",
    "      \n",
    "    # Make a figure big enough to accommodate an axis of xpixels by ypixels\n",
    "    # as well as the ticklabels, etc... (figsize is given as width, height)\n",
    "    figsize = (1 + margin) * xsize / dpi, (1 + margin) * ysize / dpi\n",
    "\n",
    "    fig = plt.figure(figsize=figsize, dpi=dpi)\n",
    "    # Make the axis the right size...\n",
    "    ax = fig.add_axes([margin, margin, 1 - 2*margin, 1 - 2*margin])\n",
    "    \n",
    "    # GetSpacing() is ordered (x, y, ...), so x uses spacing[0] and y uses spacing[1]\n",
    "    extent = (0, xsize*spacing[0], ysize*spacing[1], 0)\n",
    "    \n",
    "    t = ax.imshow(nda,extent=extent,interpolation=None)\n",
    "    \n",
    "    if nda.ndim == 2:\n",
    "        t.set_cmap(\"gray\")\n",
    "    \n",
    "    if title:\n",
    "        plt.title(title)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "myshow(sitk.Expand(img2,[2,2]), title=\"Big Visible Human Head\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tips and Tricks for Visualizing Segmentations\n",
    "\n",
    "We start by loading a segmented image. As the segmentation is just an image with integral data, we can display the labels as we would any other image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img1_seg = sitk.ReadImage(fdata(\"2th_cthead1.png\"))\n",
    "myshow(img1_seg, \"Label Image as Grayscale\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also map the scalar label image to a color image as shown below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "myshow(sitk.LabelToRGB(img1_seg), title=\"Label Image as RGB\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Most filters that take multiple images as arguments require that the images occupy the same physical space. That is, the pixels being operated on must refer to the same physical location. Luckily for us, our image and labels do occupy the same physical space, allowing us to overlay the segmentation onto the original image."
   ]
  },
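  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to check this before calling a multi-image filter is to compare the grid metadata of the two images. The helper below is a minimal sketch and not part of SimpleITK: the name `same_physical_space` and the tolerance `tol` are assumptions made for illustration, and the check that SimpleITK performs internally may use a different tolerance. It is applied to the `img1` and `img1_seg` images loaded earlier in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import SimpleITK as sitk\n",
    "\n",
    "# Hypothetical helper (not a SimpleITK function): compares size, origin,\n",
    "# spacing and direction cosines of two images, up to an assumed tolerance.\n",
    "def same_physical_space(a, b, tol=1e-6):\n",
    "    def close(u, v):\n",
    "        return len(u) == len(v) and all(abs(x - y) <= tol for x, y in zip(u, v))\n",
    "    return (a.GetSize() == b.GetSize() and\n",
    "            close(a.GetOrigin(), b.GetOrigin()) and\n",
    "            close(a.GetSpacing(), b.GetSpacing()) and\n",
    "            close(a.GetDirection(), b.GetDirection()))\n",
    "\n",
    "print(same_physical_space(img1, img1_seg))"
   ]
  },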
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "myshow(sitk.LabelOverlay(img1, img1_seg), title=\"Label Overlaid\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also overlay the labels as contours."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "myshow(sitk.LabelOverlay(img1, sitk.LabelContour(img1_seg), 1.0))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tips and Tricks for 3D Image Visualization\n",
    "\n",
    "Now let's move on to visualizing real MRI images with segmentations. The Surgical Planning Laboratory at Brigham and Women's Hospital provides a wonderful Multi-modality MRI-based Atlas of the Brain that we can use.\n",
    "\n",
    "Please note, what is done here is for convenience and is not the common way images are displayed for radiological work."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img_T1 = sitk.ReadImage(fdata(\"nac-hncma-atlas2013-Slicer4Version/Data/A1_grayT1.nrrd\"))\n",
    "img_T2 = sitk.ReadImage(fdata(\"nac-hncma-atlas2013-Slicer4Version/Data/A1_grayT2.nrrd\"))\n",
    "img_labels = sitk.ReadImage(fdata(\"nac-hncma-atlas2013-Slicer4Version/Data/hncma-atlas.nrrd\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "myshow(img_T1)\n",
    "myshow(img_T2)\n",
    "myshow(sitk.LabelToRGB(img_labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "size = img_T1.GetSize()\n",
    "myshow(img_T1[:,size[1]//2,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "slices =[img_T1[size[0]//2,:,:], img_T1[:,size[1]//2,:], img_T1[:,:,size[2]//2]]\n",
    "myshow(sitk.Tile(slices, [3,1]), dpi=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nslices = 5\n",
    "# step through the z axis, so the step is derived from the z size\n",
    "slices = [ img_T1[:,:,s] for s in range(0, size[2], size[2]//(nslices+1))]\n",
    "myshow(sitk.Tile(slices, [1,0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's create a version of the show function that allows us to select which slices are displayed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def myshow3d(img, xslices=[], yslices=[], zslices=[], title=None, margin=0.05, dpi=80):\n",
    "    size = img.GetSize()\n",
    "    img_xslices = [img[s,:,:] for s in xslices]\n",
    "    img_yslices = [img[:,s,:] for s in yslices]\n",
    "    img_zslices = [img[:,:,s] for s in zslices]\n",
    "    \n",
    "    maxlen = max(len(img_xslices), len(img_yslices), len(img_zslices))\n",
    "    \n",
    "        \n",
    "    img_null = sitk.Image([0,0], img.GetPixelID(), img.GetNumberOfComponentsPerPixel())\n",
    "    \n",
    "    img_slices = []\n",
    "    d = 0\n",
    "    \n",
    "    if len(img_xslices):\n",
    "        img_slices += img_xslices + [img_null]*(maxlen-len(img_xslices))\n",
    "        d += 1\n",
    "        \n",
    "    if len(img_yslices):\n",
    "        img_slices += img_yslices + [img_null]*(maxlen-len(img_yslices))\n",
    "        d += 1\n",
    "     \n",
    "    if len(img_zslices):\n",
    "        img_slices += img_zslices + [img_null]*(maxlen-len(img_zslices))\n",
    "        d +=1\n",
    "    \n",
    "    if maxlen != 0:\n",
    "        if img.GetNumberOfComponentsPerPixel() == 1:\n",
    "            img = sitk.Tile(img_slices, [maxlen,d])\n",
    "        #TO DO check in code to get Tile Filter working with vector images\n",
    "        else:\n",
    "            img_comps = []\n",
    "            for i in range(0,img.GetNumberOfComponentsPerPixel()):\n",
    "                img_slices_c = [sitk.VectorIndexSelectionCast(s, i) for s in img_slices]\n",
    "                img_comps.append(sitk.Tile(img_slices_c, [maxlen,d]))\n",
    "            img = sitk.Compose(img_comps)\n",
    "            \n",
    "    \n",
    "    myshow(img, title, margin, dpi)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "myshow3d(img_T1,yslices=range(50,size[1]-50,20), zslices=range(50,size[2]-50,20), dpi=30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "myshow3d(img_T2,yslices=range(50,size[1]-50,30), zslices=range(50,size[2]-50,20), dpi=30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "myshow3d(sitk.LabelToRGB(img_labels),yslices=range(50,size[1]-50,20), zslices=range(50,size[2]-50,20), dpi=30)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We next visualize the T1 image with an overlay of the labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "simpleitk_error_expected": "Both images for LabelOverlayImageFilter don't match type or dimension!"
   },
   "outputs": [],
   "source": [
    "# Why doesn't this work? The images do overlap in physical space.\n",
    "myshow3d(sitk.LabelOverlay(img_T1,img_labels),yslices=range(50,size[1]-50,20), zslices=range(50,size[2]-50,20), dpi=30)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are two ways to solve our problem: (1) resample the labels onto the image grid, or (2) resample the image onto the label grid. The difference between the two from a computational standpoint depends on the grid sizes and on the interpolator used to estimate values at non-grid locations.\n",
    "\n",
    "Note that interpolating a label image with an interpolator that can generate non-label values is problematic, as you may end up with an image that has more classes/labels than your original. This is why we only use the nearest neighbor interpolator when working with label images."
   ]
  },
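  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The effect can be seen on a small synthetic example. The cell below is an illustrative sketch, not part of the atlas workflow: it builds a tiny 2D label image containing only the labels 0 and 4 (arbitrary values chosen for the demonstration), resamples it onto a grid with half the spacing, and prints the set of values produced by each interpolator. The names `toy_labels`, `toy_grid` and `identity_2d` are made up for this example. With linear interpolation, values that are not in the original label set are expected to appear along the boundary between the two regions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import SimpleITK as sitk\n",
    "\n",
    "# Tiny synthetic label image: left half is label 0, right half is label 4.\n",
    "toy_labels = sitk.GetImageFromArray(np.array([[0, 0, 4, 4]] * 4, dtype=np.uint8))\n",
    "\n",
    "# Reference grid with the same origin and half the spacing.\n",
    "toy_grid = sitk.Image([8, 8], sitk.sitkUInt8)\n",
    "toy_grid.SetSpacing([0.5, 0.5])\n",
    "\n",
    "# A 2D identity transform, since the toy images are 2D.\n",
    "identity_2d = sitk.Transform(2, sitk.sitkIdentity)\n",
    "\n",
    "for name, interpolator in [(\"nearest neighbor\", sitk.sitkNearestNeighbor),\n",
    "                           (\"linear\", sitk.sitkLinear)]:\n",
    "    resampled = sitk.Resample(toy_labels, toy_grid, identity_2d, interpolator,\n",
    "                              0, toy_labels.GetPixelID())\n",
    "    # Print the distinct values present after resampling.\n",
    "    print(name, np.unique(sitk.GetArrayViewFromImage(resampled)))"
   ]
  },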
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Option 1: Resample the label image using the identity transformation\n",
    "resampled_img_labels = sitk.Resample(img_labels, img_T1, sitk.Transform(), sitk.sitkNearestNeighbor,\n",
    "                                     0.0, img_labels.GetPixelID())\n",
    "# Overlay the labels onto the T1 image. This requires rescaling the intensity of the T1 image to [0,255] and casting it\n",
    "# to 8-bit so that it can be combined with the color overlay (we use an alpha blending of 0.5).\n",
    "myshow3d(sitk.LabelOverlay(sitk.Cast(sitk.RescaleIntensity(img_T1), sitk.sitkUInt8),resampled_img_labels, 0.5),\n",
    "         yslices=range(50,size[1]-50,20), zslices=range(50,size[2]-50,20), dpi=30)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Option 2: Resample the T1 image using the identity transformation\n",
    "resampled_T1 = sitk.Resample(img_T1, img_labels, sitk.Transform(), sitk.sitkLinear,\n",
    "                             0.0, img_T1.GetPixelID())\n",
    "# Overlay the labels onto the resampled T1 image. This requires rescaling the intensity of the resampled image to\n",
    "# [0,255] and casting it to 8-bit so that it can be combined with the color overlay (we use an alpha blending of 0.5).\n",
    "myshow3d(sitk.LabelOverlay(sitk.Cast(sitk.RescaleIntensity(resampled_T1), sitk.sitkUInt8),img_labels, 0.5),\n",
    "         yslices=range(50,size[1]-50,20), zslices=range(50,size[2]-50,20), dpi=30)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why are the two displays above different? (hint: in the calls to the \"myshow3d\" function the indexes of the y and z slices are the same). "
   ]
  },
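  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to start investigating is to compare the grids of the two images and to look at which physical location a given voxel index maps to in each of them. The cell below is a sketch that reuses the `img_T1` and `img_labels` variables loaded earlier; the index stored in `probe_index` is an arbitrary choice made for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare the grid definitions of the two images.\n",
    "for name, image in [(\"T1\", img_T1), (\"labels\", img_labels)]:\n",
    "    print(name, \"size:\", image.GetSize(), \"spacing:\", image.GetSpacing(), \"origin:\", image.GetOrigin())\n",
    "\n",
    "# Physical location of the same voxel index in each image grid (arbitrary index).\n",
    "probe_index = (0, 50, 50)\n",
    "print(\"T1:\", img_T1.TransformIndexToPhysicalPoint(probe_index))\n",
    "print(\"labels:\", img_labels.TransformIndexToPhysicalPoint(probe_index))"
   ]
  },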
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The ``myshow`` and ``myshow3d`` functions are really useful. They have been copied into a \"myshow.py\" file so that they can be imported into other notebooks."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
